{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training your own custom model using `solaris`\n",
    "\n",
    "If you want to go beyond using [the pre-trained models in solaris](../../pretrained_models.html), you can train your own. This primer walks through training the [SpaceNet 4 Baseline model](https://github.com/cosmiq/cosmiq_sn4_baseline) from scratch. If you want to use one of the existing models in `solaris`, [check out this tutorial](api_training_spacenet.ipynb).\n",
    "\n",
    "First, you'll need to [create a YAML config file](creating_the_yaml_config_file.ipynb) for your model. This config should differ from the config for a pre-trained model in two key places:\n",
    "\n",
    "- `model_name`: Don't reuse the name of a pre-trained model in `solaris`; give your model a different name.\n",
    "- `model_path`: If you have pre-trained weights to load, put the path to those weights here; otherwise, leave it blank.\n",
    "\n",
    "Fill out all of the model-specific parameters (width/height of inputs, mask channels, the neural network framework, optimizer, learning rate, etc.) according to the model you plan to use.\n",
    "\n",
    "Next, you'll need to create your model. See below for the SpaceNet 4 Baseline example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose\n",
    "from tensorflow.keras.layers import concatenate, BatchNormalization, Dropout\n",
    "from tensorflow.keras import Model\n",
    "\n",
    "def cosmiq_sn4_baseline(input_shape=(512, 512, 3), base_depth=64):\n",
    "    \"\"\"Keras implementation of untrained TernausNet model architecture.\n",
    "\n",
    "    Arguments:\n",
    "    ----------\n",
    "    input_shape (3-tuple): a tuple defining the shape of the input image.\n",
    "    base_depth (int): the base convolution filter depth for the first layer\n",
    "        of the model. Must be divisible by two, as the final layer uses\n",
    "        base_depth/2 filters. The default value, 64, corresponds to the\n",
    "        original TernausNetV1 depth.\n",
    "\n",
    "    Returns:\n",
    "    --------\n",
    "    An uncompiled Keras Model instance with TernausNetV1 architecture.\n",
    "\n",
    "    \"\"\"\n",
    "    inputs = Input(input_shape)\n",
    "    conv1 = Conv2D(base_depth, 3, activation='relu', padding='same')(inputs)\n",
    "    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)\n",
    "\n",
    "    conv2_1 = Conv2D(base_depth*2, 3, activation='relu',\n",
    "                     padding='same')(pool1)\n",
    "    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2_1)\n",
    "\n",
    "    conv3_1 = Conv2D(base_depth*4, 3, activation='relu',\n",
    "                     padding='same')(pool2)\n",
    "    conv3_2 = Conv2D(base_depth*4, 3, activation='relu',\n",
    "                     padding='same')(conv3_1)\n",
    "    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3_2)\n",
    "\n",
    "    conv4_1 = Conv2D(base_depth*8, 3, activation='relu',\n",
    "                     padding='same')(pool3)\n",
    "    conv4_2 = Conv2D(base_depth*8, 3, activation='relu',\n",
    "                     padding='same')(conv4_1)\n",
    "    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4_2)\n",
    "\n",
    "    conv5_1 = Conv2D(base_depth*8, 3, activation='relu',\n",
    "                     padding='same')(pool4)\n",
    "    conv5_2 = Conv2D(base_depth*8, 3, activation='relu',\n",
    "                     padding='same')(conv5_1)\n",
    "    pool5 = MaxPooling2D(pool_size=(2, 2))(conv5_2)\n",
    "\n",
    "    conv6_1 = Conv2D(base_depth*8, 3, activation='relu',\n",
    "                     padding='same')(pool5)\n",
    "\n",
    "    up7 = Conv2DTranspose(base_depth*4, 2, strides=(2, 2), activation='relu',\n",
    "                          padding='same')(conv6_1)\n",
    "    concat7 = concatenate([up7, conv5_2])\n",
    "    conv7_1 = Conv2D(base_depth*8, 3, activation='relu',\n",
    "                     padding='same')(concat7)\n",
    "\n",
    "    up8 = Conv2DTranspose(base_depth*4, 2, strides=(2, 2), activation='relu',\n",
    "                          padding='same')(conv7_1)\n",
    "    concat8 = concatenate([up8, conv4_2])\n",
    "    conv8_1 = Conv2D(base_depth*8, 3, activation='relu',\n",
    "                     padding='same')(concat8)\n",
    "\n",
    "    up9 = Conv2DTranspose(base_depth*2, 2, strides=(2, 2), activation='relu',\n",
    "                          padding='same')(conv8_1)\n",
    "    concat9 = concatenate([up9, conv3_2])\n",
    "    conv9_1 = Conv2D(base_depth*4, 3, activation='relu',\n",
    "                     padding='same')(concat9)\n",
    "\n",
    "    up10 = Conv2DTranspose(base_depth, 2, strides=(2, 2), activation='relu',\n",
    "                           padding='same')(conv9_1)\n",
    "    concat10 = concatenate([up10, conv2_1])\n",
    "    conv10_1 = Conv2D(base_depth*2, 3, activation='relu',\n",
    "                      padding='same')(concat10)\n",
    "\n",
    "    up11 = Conv2DTranspose(int(base_depth/2), 2, strides=(2, 2),\n",
    "                           activation='relu', padding='same')(conv10_1)\n",
    "    concat11 = concatenate([up11, conv1])\n",
    "\n",
    "    out = Conv2D(1, 1, activation='sigmoid', padding='same')(concat11)\n",
    "\n",
    "    return Model(inputs=inputs, outputs=out)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, put that model function into a custom model dictionary for the `solaris` model trainer. Note that `arch` is set to the function itself, not to a model instance created by calling it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "custom_model_dict = {'model_name': 'cosmiq_sn4_baseline',\n",
    "                     'weight_path': None,\n",
    "                     'weight_url': None,\n",
    "                     'arch': cosmiq_sn4_baseline}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now you can follow roughly the same process as for a pre-trained model: load in the config file, then create your trainer. The major difference here is that you'll pass an additional argument to the trainer, `custom_model_dict`, which provides the model architecture to the trainer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import solaris as sol\n",
    "\n",
    "# Replace this path with the location of your own YAML config file.\n",
    "config = sol.utils.config.parse('/Users/nweir/code/cosmiq_repos/solaris/cosmiq_sn4_baseline.yml')\n",
    "trainer = sol.nets.train.Trainer(config, custom_model_dict=custom_model_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From here, training works the same way as it does for a pre-trained model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "solaris",
   "language": "python",
   "name": "solaris"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
